{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a126c44a-e938-414d-8b60-a4d9dafa6e24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2 (reload modules before each cell runs),\n",
    "# so edits to the project modules imported below are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7d55b6aa-32ea-4a5b-b11e-dd72d67f0ad5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/data/work/jiaqi/targetdiff\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Move from the notebook's directory up to the repository root so the relative paths used in\n",
    "# later cells (configs/, pretrained_models/, examples/) resolve.\n",
    "# The guard keeps the cell idempotent: once configs/sampling.yml is reachable from the working\n",
    "# directory, re-running the cell no longer climbs another level.\n",
    "if not os.path.isfile('configs/sampling.yml'):\n",
    "    %cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "add16081-1742-4c9c-8a0b-be0f5c961ad9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import torch\n",
    "from datasets import get_dataset\n",
    "import utils.misc as misc\n",
    "import utils.transforms as trans\n",
    "from torch_geometric.transforms import Compose\n",
    "from tqdm.auto import tqdm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "409453c5-b142-43c8-ba72-ff4273486839",
   "metadata": {},
   "outputs": [],
   "source": [
    "from models.molopt_score_model import ScorePosNet3D\n",
    "from scripts.likelihood_est_diffusion import data_likelihood_estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5b4b8010-3485-4392-86ee-3d7498891bdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the sampling configuration; its model.checkpoint entry is the checkpoint path used below.\n",
    "sampling_config = misc.load_config('configs/sampling.yml')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6bb57999-ff4d-457f-9632-fcd25905f988",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first GPU when CUDA is available; otherwise fall back to the CPU so the notebook\n",
    "# still runs on machines without a GPU.\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Load checkpoint\n",
    "# NOTE(review): torch.load unpickles the file, which can execute arbitrary code --\n",
    "# only load checkpoints from a trusted source.\n",
    "ckpt = torch.load(sampling_config.model.checkpoint, map_location=device)\n",
    "\n",
    "# Transforms\n",
    "# The ligand atom mode is read from the config stored inside the checkpoint, so the ligand\n",
    "# featurization follows the checkpoint rather than a value hard-coded in this notebook.\n",
    "protein_featurizer = trans.FeaturizeProteinAtom()\n",
    "ligand_atom_mode = ckpt['config'].data.transform.ligand_atom_mode\n",
    "ligand_featurizer = trans.FeaturizeLigandAtom(ligand_atom_mode)\n",
    "transform = Compose([\n",
    "    protein_featurizer,\n",
    "    ligand_featurizer,\n",
    "    trans.FeaturizeLigandBond(),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0b8348e4-4b92-478d-b8d3-86eaf965973d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully load the model! ./pretrained_models/pretrained_diffusion.pt\n"
     ]
    }
   ],
   "source": [
    "# Load model\n",
    "# The network is built from the model config stored in the checkpoint; the feature dimensions\n",
    "# come from the featurizers created above.\n",
    "model = ScorePosNet3D(\n",
    "    ckpt['config'].model,\n",
    "    protein_atom_feature_dim=protein_featurizer.feature_dim,\n",
    "    ligand_atom_feature_dim=ligand_featurizer.feature_dim\n",
    ").to(device)\n",
    "# strict=True makes a mismatch between checkpoint keys and model parameters an error.\n",
    "model.load_state_dict(ckpt['model'], strict=True)\n",
    "# The model is only used for inference in this notebook, so put it in evaluation mode to\n",
    "# disable any training-only layer behaviour (e.g. dropout).\n",
    "model.eval()\n",
    "print(f'Successfully load the model! {sampling_config.model.checkpoint}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bc49da98-2ecc-4fe5-aec5-ddd1133d05df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.data import PDBProtein, parse_sdf_file\n",
    "from datasets.pl_data import ProteinLigandData, torchify_dict\n",
    "from torch.utils.data import Dataset\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.data import Batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4b29b9cc-8bb1-4ee2-a575-281b90239493",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_data(pdb_path, ligand_path, transform, radius=10, pocket=False):\n",
    "    \"\"\"Build a protein-ligand data object from a PDB file and a ligand SDF file.\n",
    "\n",
    "    Args:\n",
    "        pdb_path: Path handed to PDBProtein. With pocket=False the residues selected around\n",
    "            the ligand are extracted from it first; with pocket=True it is used whole.\n",
    "        ligand_path: Path of the ligand SDF file, parsed with parse_sdf_file.\n",
    "        transform: Callable applied to the assembled data object, or None to skip it.\n",
    "        radius: Cutoff passed to PDBProtein.query_residues_ligand (only used when\n",
    "            pocket=False).\n",
    "        pocket: Set to True to skip the residue selection step.\n",
    "\n",
    "    Returns:\n",
    "        The ProteinLigandData object, passed through transform when one is given.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If data.protein_pos is empty along its first dimension.\n",
    "    \"\"\"\n",
    "    ligand_dict = parse_sdf_file(ligand_path)\n",
    "\n",
    "    if pocket:\n",
    "        # The structure file is used as-is, without any residue selection.\n",
    "        pocket_dict = PDBProtein(pdb_path).to_dict_atom()\n",
    "    else:\n",
    "        # Select residues around the ligand, write them to a PDB block and re-parse that block.\n",
    "        full_protein = PDBProtein(pdb_path)\n",
    "        nearby_residues = full_protein.query_residues_ligand(ligand_dict, radius)\n",
    "        pocket_block = full_protein.residues_to_pdb_block(nearby_residues)\n",
    "        pocket_dict = PDBProtein(pocket_block).to_dict_atom()\n",
    "\n",
    "    data = ProteinLigandData.from_protein_ligand_dicts(\n",
    "        protein_dict=torchify_dict(pocket_dict),\n",
    "        ligand_dict=torchify_dict(ligand_dict),\n",
    "    )\n",
    "    # Record the source files on the data object.\n",
    "    data.protein_filename = pdb_path\n",
    "    data.ligand_filename = ligand_path\n",
    "    assert data.protein_pos.size(0) > 0\n",
    "\n",
    "    return data if transform is None else transform(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7280dc63-7b98-4771-af61-13abbd0b7262",
   "metadata": {},
   "outputs": [],
   "source": [
    "class InferenceDataset(Dataset):\n",
    "    \"\"\"Minimal map-style dataset that serves items from an in-memory list.\"\"\"\n",
    "\n",
    "    def __init__(self, data_list):\n",
    "        \"\"\"Keep a reference to data_list; items are returned exactly as stored.\"\"\"\n",
    "        super().__init__()\n",
    "        self.data_list = data_list\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the stored item at position idx.\"\"\"\n",
    "        return self.data_list[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return how many items the dataset holds.\"\"\"\n",
    "        return len(self.data_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "52342244-21b3-473c-b9c0-575975148a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Featurize the bundled 3ug2 example complex.\n",
    "test_data = convert_data(\n",
    "    pdb_path='examples/3ug2_protein.pdb',\n",
    "    ligand_path='examples/3ug2_ligand.sdf',\n",
    "    transform=transform,\n",
    ")\n",
    "\n",
    "# Serve it through a one-item loader. follow_batch asks the loader for the\n",
    "# protein_element_batch / ligand_element_batch vectors read by the embedding call below.\n",
    "test_set = InferenceDataset([test_data])\n",
    "test_loader = DataLoader(\n",
    "    test_set,\n",
    "    batch_size=1,\n",
    "    shuffle=False,\n",
    "    follow_batch=['protein_element', 'ligand_element'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "82b452bc-5ac3-4ee1-8907-2dcd8925abb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first batch from the loader, then move it to the model's device.\n",
    "batch = next(iter(test_loader))\n",
    "batch = batch.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2947021d-a280-4699-9c95-c6ede8c98cb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the model's embedding pass on the batch. The protein features are cast to float;\n",
    "# the *_element_batch vectors are the ones requested via follow_batch in the loader.\n",
    "preds = model.fetch_embedding(\n",
    "    protein_pos=batch.protein_pos,\n",
    "    protein_v=batch.protein_atom_feature.float(),\n",
    "    batch_protein=batch.protein_element_batch,\n",
    "    ligand_pos=batch.ligand_pos,\n",
    "    ligand_v=batch.ligand_atom_feature_full,\n",
    "    batch_ligand=batch.ligand_element_batch,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f245616a-403a-4006-87ac-de89a6ad0ef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the linear model whose predict() turns the pooled ligand embedding into pka below.\n",
    "# NOTE(review): pickle.load can execute arbitrary code -- only unpickle trusted files.\n",
    "with open('pretrained_models/pk_reg_para.pkl', 'rb') as model_file:\n",
    "    lmodel = pickle.load(model_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "661ce15e-9485-4429-85ff-2ed43de5aa67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the final ligand embedding over axis 0 and keep that axis, giving a single-row\n",
    "# array to pass to lmodel.predict.\n",
    "# NOTE(review): presumably axis 0 indexes ligand atoms; pooling over the whole tensor is only\n",
    "# per-complex while the batch holds one complex (batch_size=1) -- confirm before batching.\n",
    "final_ligand_h = preds['final_ligand_h'].cpu().numpy().mean(axis=0, keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "80444932-f508-4b68-8d00-80becf7ffb28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict from the pooled ligand embedding with the pickled linear model; the result is\n",
    "# converted with 10 ** -pka in the following cell.\n",
    "pka = lmodel.predict(final_ligand_h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b05a8190-d260-4d3d-a614-96b2a7bcc88e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Undo the negative base-10 logarithm: affinity = 10 ** (-pka).\n",
    "# NOTE(review): the unit of the result depends on how the regressor was fitted -- confirm.\n",
    "affinity = np.power(10, np.negative(pka))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "340eebb5-0307-4cd1-9461-b932d59ec90c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.0119374e-09], dtype=float32)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the predicted affinity (10 ** -pka) for the example complex.\n",
    "affinity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9597a9b5-84b0-4892-b12a-cf80a5fb1681",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Pocket2Mol",
   "language": "python",
   "name": "pocket2mol"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
